{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "daa4280a-d000-46e2-a92e-f243f080c704",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sentence_transformers/cross_encoder/CrossEncoder.py:11: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
      "  from tqdm.autonotebook import tqdm, trange\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1d03b708904d4b1ba03c2679c38a87bf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "import os\n",
    "# Point Hugging Face Hub downloads at the hf-mirror.com mirror.\n",
    "# NOTE(review): presumably this has to be set before sentence_transformers is imported\n",
    "# so the hub client sees it -- confirm before reordering these lines.\n",
    "os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n",
    "\n",
    "from sentence_transformers import SentenceTransformer\n",
    "# Two sample sentences ('样例数据' = 'sample data') used to try the model out.\n",
    "sentences_1 = [\"样例数据-1\", \"样例数据-2\"]\n",
    "# Embedding model; later cells reuse it through the global name `model`.\n",
    "model = SentenceTransformer('BAAI/bge-small-zh-v1.5')\n",
    "# Encode the samples with normalize_embeddings=True and a progress bar.\n",
    "# embeddings_1 is not referenced by any other cell visible in this notebook.\n",
    "embeddings_1 = model.encode(sentences_1, normalize_embeddings=True, show_progress_bar=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "b10d3b0a-b95e-4b7d-a652-e3ea28b1cead",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'早上好🌞！新的一天开始了，祝您心情愉快，工作顺利！有什么可以帮助您的吗？'"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import time\n",
    "import jwt\n",
    "import requests\n",
    "\n",
    "def generate_token(apikey: str, exp_seconds: int):\n",
    "    \"\"\"Build an HS256-signed JWT from an ``<id>.<secret>`` API key.\n",
    "\n",
    "    Args:\n",
    "        apikey: The real API key, formatted as ``\"<id>.<secret>\"`` (exactly one dot).\n",
    "        exp_seconds: Token lifetime in seconds, counted from the moment of the call.\n",
    "\n",
    "    Returns:\n",
    "        The encoded token returned by ``jwt.encode``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``apikey`` does not consist of two non-empty parts separated\n",
    "            by a single dot. ValueError is an ``Exception`` subclass, so callers\n",
    "            that catch ``Exception`` keep working.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        # `key_id` rather than `id`, to avoid shadowing the builtin.\n",
    "        key_id, secret = apikey.split(\".\")\n",
    "    except Exception as e:\n",
    "        raise ValueError(f\"invalid apikey: {e}\") from e\n",
    "    if not key_id or not secret:\n",
    "        raise ValueError(\"invalid apikey: id and secret must both be non-empty\")\n",
    "\n",
    "    # Read the clock once so that exp - timestamp is exactly the requested lifetime.\n",
    "    now_ms = int(round(time.time() * 1000))\n",
    "    payload = {\n",
    "        \"api_key\": key_id,\n",
    "        \"exp\": now_ms + exp_seconds * 1000,  # milliseconds, same unit as \"timestamp\"\n",
    "        \"timestamp\": now_ms,\n",
    "    }\n",
    "    return jwt.encode(\n",
    "        payload,\n",
    "        secret,\n",
    "        algorithm=\"HS256\",\n",
    "        headers={\"alg\": \"HS256\", \"sign_type\": \"SIGN\"},\n",
    "    )\n",
    "\n",
    "def glm4air(prompt, api_key=None):\n",
    "    \"\"\"Send one user message to the glm-4-air chat model and return the reply text.\n",
    "\n",
    "    Args:\n",
    "        prompt: Content of the single user message.\n",
    "        api_key: API key in ``\"<id>.<secret>\"`` form. When omitted, the key is read\n",
    "            from the ``ZHIPUAI_API_KEY`` environment variable so that it never has\n",
    "            to be written into the notebook.\n",
    "\n",
    "    Returns:\n",
    "        The ``content`` field of the first choice in the JSON response.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If no key is passed and ``ZHIPUAI_API_KEY`` is not set.\n",
    "        requests.HTTPError: If the API responds with a 4xx/5xx status.\n",
    "    \"\"\"\n",
    "    import os  # local import keeps this cell runnable on its own\n",
    "\n",
    "    api_key = api_key or os.environ.get(\"ZHIPUAI_API_KEY\")\n",
    "    if not api_key:\n",
    "        raise ValueError(\"no API key: pass api_key=... or set ZHIPUAI_API_KEY\")\n",
    "\n",
    "    url = \"https://open.bigmodel.cn/api/paas/v4/chat/completions\"\n",
    "    headers = {\n",
    "      'Content-Type': 'application/json',\n",
    "      # A fresh token is generated per request, valid for 1000 seconds.\n",
    "      'Authorization': generate_token(api_key, 1000)\n",
    "    }\n",
    "\n",
    "    data = {\n",
    "        \"model\": \"glm-4-air\",\n",
    "        \"messages\": [{\"role\": \"user\", \"content\":  prompt }]\n",
    "    }\n",
    "\n",
    "    response = requests.post(url, headers=headers, json=data, timeout=200)\n",
    "    # Surface HTTP failures with their status instead of an opaque KeyError on 'choices'.\n",
    "    response.raise_for_status()\n",
    "    return response.json()['choices'][0]['message']['content']\n",
    "\n",
    "# Quick end-to-end check: send a greeting ('早上好' = 'good morning') and display the reply.\n",
    "glm4air('早上好')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "f3ad19be-23cd-4104-addb-984d58403ff0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the corpus through a context manager so the file handle is closed, and decode\n",
    "# explicitly instead of relying on the platform default encoding.\n",
    "# NOTE(review): assumes corpus.txt is UTF-8 -- confirm against the data source.\n",
    "with open('corpus.txt', encoding='utf-8') as corpus_file:\n",
    "    corpus = corpus_file.readlines()  # one entry per line; trailing newlines are kept\n",
    "\n",
    "# Questions to answer, taken from the 'question' column as a NumPy array.\n",
    "query = pd.read_csv('test_question.csv')['question'].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "155573a8-5f21-4e8e-bb41-775ab2b77e64",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9d9b7aaa3310478faa3a5e0da4841792",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/11 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d23e4e979c024f65be939a5e1c3b0147",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/1563 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Embed the questions first, then the corpus, with identical encoder settings.\n",
    "# normalize_embeddings=True is kept for both calls, exactly as before.\n",
    "query_feat, corpus_feat = (\n",
    "    model.encode(texts, normalize_embeddings=True, show_progress_bar=True)\n",
    "    for texts in (query, corpus)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "dabe67cb-0258-4ec6-8b3d-7d395471908e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "332it [22:10,  4.01s/it]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# For every question: retrieve the single most similar corpus entry, put it in the\n",
    "# prompt, and ask the chat model. One answer is appended per question, in order.\n",
    "test_answer = []\n",
    "# zip() has no length, so pass total= to get a real progress bar and ETA.\n",
    "for q, qtext in tqdm(zip(query_feat, query), total=len(query)):\n",
    "    # Dot product against every corpus embedding; argmax picks the best match in O(n)\n",
    "    # without sorting the whole score vector.\n",
    "    top1_content = corpus[np.dot(q, corpus_feat.T).argmax()]\n",
    "    prompt_text = f'''请结合下面的资料，回答给定的问题：\n",
    "\n",
    "提问：{qtext}\n",
    "\n",
    "相关资料：{top1_content}\n",
    "'''\n",
    "\n",
    "    try:\n",
    "        answer = glm4air(prompt_text)\n",
    "        answer = answer.replace('\\n', '')  # keep each answer on a single line\n",
    "        test_answer.append(answer)\n",
    "    except Exception as err:\n",
    "        # `except Exception` instead of a bare except: interrupting the kernel now\n",
    "        # stops the loop, and failures are reported instead of silently hidden.\n",
    "        tqdm.write(f'glm4air failed for {qtext!r}: {err!r}')\n",
    "        test_answer.append('无法回答。')  # fallback answer ('unable to answer')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "d0c95459-3a70-4244-9e0b-873f0042095a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the collected answers in a DataFrame and label its only column 'answer'.\n",
    "# The name test_answer is kept because the export cell reads it.\n",
    "test_answer = pd.DataFrame(test_answer).set_axis(['answer'], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "04449ee1-67af-49ef-80f3-b7f44bbfb4fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the answers to submit.csv without the row index.\n",
    "# index=False is the documented boolean form; index=None only worked because it is falsy.\n",
    "test_answer.to_csv('submit.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90743b90-6046-49fb-aff1-7e0dc3ae0abb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.11",
   "language": "python",
   "name": "py3.11"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
